P4248 [AHOI2013]差异

1i<jnlen(Ti)+len(Tj)2×lcp(Ti,Tj)\sum_{1\leqslant i <j\leqslant n}\text{len}(T_i)+\text{len}(T_j)-2\times\text{lcp}(T_i,T_j)

(1i<jnlen(Ti)+len(Tj))21i<jnlcp(Ti,Tj)\left(\sum_{1\leqslant i <j\leqslant n}\text{len}(T_i)+\text{len}(T_j)\right)-2\sum_{1\leqslant i <j\leqslant n}\text{lcp}(T_i,T_j)

(1i<jni+j)21i<jnlcp(sai,saj)\left(\sum_{1\leqslant i <j\leqslant n}i+j\right)-2\sum_{1\leqslant i <j\leqslant n}\text{lcp}(sa_i,sa_j)

前一个和式可以计算,不再赘述,考虑后一个和式。

我们知道:lcp(sai,saj)=minikj{heightk}\displaystyle\text{lcp}(sa_i,sa_j)=\min_{i \le k \le j} \{\text{height}_k\}

对于每一个 height\text{height} 用单调栈找到左右比它小的第一个元素的位置即可。

注意有一边严格小,另一边不大于,这样才不会重复遗漏。

#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
#define LL long long

const int MAXN = 5e5;

int rk[ MAXN + 5 ] , sa[ MAXN + 5 ] , ht[ MAXN + 5 ];
int tmp[ MAXN + 5 ] , buc[ MAXN + 5 ];
void SA( int len , char *str ) {
    int s = 256;
    for( int i = 1 ; i <= s ; i ++ ) buc[ i ] = 0;
    for( int i = 1 ; i <= len ; i ++ ) buc[ rk[ i ] = str[ i ] ] ++;
    for( int i = 1 ; i <= s ; i ++ ) buc[ i ] += buc[ i - 1 ];
    for( int i = len ; i >= 1 ; i -- ) sa[ buc[ rk[ i ] ] -- ] = i;

    for( int w = 1 ; w <= len ; w <<= 1 ) {
        int cnt = 0;
        for( int i = len - w + 1 ; i <= len ; i ++ ) tmp[ ++ cnt ] = i;
        for( int i = 1 ; i <= len ; i ++ ) if( sa[ i ] > w ) tmp[ ++ cnt ] = sa[ i ] - w;

        for( int i = 1 ; i <= s ; i ++ ) buc[ i ] = 0;
        for( int i = 1 ; i <= len ; i ++ ) buc[ rk[ i ] ] ++;
        for( int i = 1 ; i <= s ; i ++ ) buc[ i ] += buc[ i - 1 ];
        for( int i = len ; i >= 1 ; i -- ) sa[ buc[ rk[ tmp[ i ] ] ] -- ] = tmp[ i ] , tmp[ i ] = 0;
        
        swap( tmp , rk );
        rk[ sa[ 1 ] ] = cnt = 1;
        for( int i = 2 ; i <= len ; i ++ ) rk[ sa[ i ] ] = ( tmp[ sa[ i ] ] == tmp[ sa[ i - 1 ] ] && tmp[ sa[ i ] + w ] == tmp[ sa[ i - 1 ] + w ] ) ? cnt : ++ cnt;
        if( cnt == len ) break; s = cnt;
    }

    for( int i = 1 , k = 0 ; i <= len ; i ++ ) {
        if( k ) k --;
        while( str[ sa[ rk[ i ] ] + k ] == str[ sa[ rk[ i ] - 1 ] + k ] ) k ++;
        ht[ rk[ i ] ] = k;
    }
}

int n; char str[ MAXN + 5 ];
int top , stk[ MAXN + 5 ] , l[ MAXN + 5 ] , r[ MAXN + 5 ];
LL Ans;
int main( ) {
	scanf("%s", str + 1 ); n = strlen( str + 1 );

    SA( n , str );
    // for( int i = 1 ; i <= n ; i ++ ) printf("%d%c", ht[ i ] , i == n ? '\n' : ' ' );
    ht[ stk[ top = 1 ] = 0 ] = -1;
    for( int i = 1 ; i <= n ; i ++ ) {
        while( top && ht[ i ] < ht[ stk[ top ] ] ) top --;
        l[ i ] = stk[ top ]; stk[ ++ top ] = i;
    }
    ht[ stk[ top = 1 ] = n + 1 ] = -1;
    for( int i = n ; i >= 1 ; i -- ) {
        while( top && ht[ i ] <= ht[ stk[ top ] ] ) top --;
        r[ i ] = stk[ top ]; stk[ ++ top ] = i;
    }
    // for( int i = 1 ; i <= n ; i ++ ) printf("%d %d\n", l[ i ] , r[ i ] );
    
    for( int i = 1 ; i <= n ; i ++ ) Ans -= 2ll * ht[ i ] * ( i - l[ i ] ) * ( r[ i ] - i );
    for( int i = n ; i >= 2 ; i -- ) Ans += 3ll * ( i - 1 ) * i / 2;
    printf("%lld\n", Ans );
    return 0;
}
/*
eededeedeedeedde
1798
*/